package org.apache.catalina.session;

import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Set;
import java.util.UUID;

import javax.servlet.ServletException;

import org.apache.catalina.Container;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
import org.apache.catalina.Manager;
import org.apache.catalina.Pipeline;
import org.apache.catalina.Session;
import org.apache.catalina.Valve;
import org.apache.catalina.connector.Request;
import org.apache.catalina.connector.Response;
import org.apache.catalina.session.JedisUtil.Callback;
import org.apache.catalina.valves.ValveBase;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;

import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.JedisSentinelPool;
import redis.clients.jedis.Protocol;

/**
 * Tomcat session manager that keeps sessions in Redis instead of local memory.
 * <p>
 * A session is loaded from Redis the first time it is looked up ({@link #findSession(String)}),
 * held in the local {@link ManagerBase} map while requests use it, and at the end of a request
 * {@link RedisSessionValve} writes it back (when dirty, or always if {@code saveOnDirty} is
 * false) and evicts it from the local map. Redis keys are {@code prefix + sessionId}; values are
 * stored with {@code SETEX} using {@code maxInactiveInterval} seconds as the TTL.
 * <p>
 * The connection mode is chosen from {@code hostAndPorts}: more than one node gives a
 * {@link JedisCluster}, or a {@link JedisSentinelPool} when {@code sentinelMaster} is set;
 * otherwise a single {@link JedisPool} on {@code host:port} is used.
 * <p>
 * NOTE(review): the connection objects, {@code poolConfig}, {@code sentinelMaster} and
 * {@code maxInactiveInterval} are {@code static}, so every instance of this manager in the JVM
 * (e.g. one per web application) shares them: the last instance started replaces the connection,
 * and stopping any instance closes it (via JedisUtil.close) for all. Confirm only one instance is
 * ever deployed, or make these per-instance.
 * <p>
 * NOTE(review): only the session whose id the client sent with the request is flushed afterwards.
 * A session created during the request, or one whose id was changed during it, stays only in this
 * node's memory until a later request carrying the new id reaches this node.
 *
 * @see <a href="https://github.com/jcoleman/tomcat-redis-session-manager">tomcat-redis-session-manager</a>
 */
public class RedisSessionManager extends ManagerBase {
	static Log log = LogFactory.getLog(RedisSessionManager.class);
	/** Charset for Redis keys; UTF-8 is also what Jedis uses for its String-key commands. */
	private static final Charset KEY_CHARSET = Charset.forName("UTF-8");

	// Connection state. Static, hence shared by all instances (see class comment).
	static JedisPoolConfig poolConfig = new JedisPoolConfig();
	static JedisPool pool = null;
	static JedisCluster cluster = null;
	static JedisSentinelPool sentinel = null;
	static String sentinelMaster = null;
	/** TTL in seconds of stored sessions, and the max inactive interval given to new sessions. */
	static int maxInactiveInterval = 1800;

	String host = Protocol.DEFAULT_HOST;
	int port = Protocol.DEFAULT_PORT;
	int timeout = Protocol.DEFAULT_TIMEOUT;
	String password = null;
	int database = Protocol.DEFAULT_DATABASE;
	/** Prepended to every session id to form its Redis key. */
	String prefix = "";
	/** Node list parsed by JedisUtil.nodes(); more than one node selects cluster or sentinel mode. */
	String hostAndPorts = null;
	/** When true, a session is written back only if it was modified during the request. */
	boolean saveOnDirty = true;

	/**
	 * Flushes a session back to Redis at the end of a request and drops it from the local map.
	 * <p>
	 * The session is written when it is dirty, or unconditionally when {@code saveOnDirty} is
	 * false. Nothing happens if the id is null or the session is not currently held locally.
	 *
	 * @param sessionId the session id the client sent with the request; may be null
	 */
	void afterRequest(String sessionId) {
		if (sessionId == null) {
			return;
		}
		try {
			RedisSession session = (RedisSession) super.findSession(sessionId);
			if (session != null) {
				if (!saveOnDirty || session.dirty) {
					saveSession(session);
				}
				// Evict only locally: this class's remove() would also delete the Redis key.
				super.remove(session, false);
			}
		} catch (IOException e) {
			log.warn(sessionId + " afterRequest failed", e);
		}
	}

	/**
	 * Serializes the session and stores it under {@code prefix + id} with a TTL of
	 * {@code maxInactiveInterval} seconds.
	 *
	 * @param session the session to persist
	 */
	private void saveSession(final RedisSession session) {
		session.endAccess();
		final String id = session.getId();
		final byte[] key = redisKey(id);
		Boolean saved = JedisUtil.exec(pool, sentinel, cluster, new Callback<Boolean>() {
			public Boolean execute(Jedis jedis) {
				byte[] bytes = JedisUtil.serialize(session);
				// null presumably means serialization failed; report that as "not saved".
				return bytes != null && "OK".equals(jedis.setex(key, maxInactiveInterval, bytes));
			}
		});
		log.debug(id + " save=" + saved);
	}

	/** Builds the Redis key for a session id: {@code prefix + id}, UTF-8 encoded. */
	private byte[] redisKey(String id) {
		return (prefix + id).getBytes(KEY_CHARSET);
	}

	/**
	 * Starts the manager: initialises ManagerBase, installs the flushing valve and opens the Redis
	 * connection (single server, sentinel or cluster).
	 *
	 * @throws LifecycleException if ManagerBase fails to start
	 */
	protected void startInternal() throws LifecycleException {
		// ManagerBase sets up its own state here (e.g. the session id generator that
		// changeSessionId() relies on); skipping it leaves that state uninitialised.
		super.startInternal();

		installValve();

		JedisUtil.setClassLoader(this);
		connect();

		// LifecycleBase expects the state to be STARTING once this method returns.
		setState(LifecycleState.STARTING);
	}

	/**
	 * Reuses a RedisSessionValve already in the container's pipeline, or adds a new one, and binds
	 * it to this manager. A new valve is fully configured before it is added to the pipeline.
	 */
	private void installValve() {
		Container container = JedisUtil.getContainer(this);
		Pipeline pipeline = container.getPipeline();
		RedisSessionValve valve = findRedisSessionValve(pipeline);
		if (valve == null) {
			valve = new RedisSessionValve();
			valve.setRedisSessionManager(this);
			valve.setContainer(container);
			valve.setAsyncSupported(true);
			pipeline.addValve(valve);
		} else {
			valve.setRedisSessionManager(this);
		}
	}

	/** Returns the first RedisSessionValve in the pipeline, or null if there is none. */
	private static RedisSessionValve findRedisSessionValve(Pipeline pipeline) {
		Valve[] valves = pipeline.getValves();
		if (valves != null) {
			for (Valve valve : valves) {
				if (valve instanceof RedisSessionValve) {
					return (RedisSessionValve) valve;
				}
			}
		}
		return null;
	}

	/** Opens the Redis connection in the mode selected by hostAndPorts / sentinelMaster. */
	private void connect() {
		log.info("JedisPool maxTotal="+poolConfig.getMaxTotal()+", maxIdle="+poolConfig.getMaxIdle()+", minIdle="+poolConfig.getMinIdle()+", lifo="+poolConfig.getLifo());
		Set<HostAndPort> nodes = JedisUtil.nodes(hostAndPorts);
		if (nodes != null && nodes.size() > 1) {
			if (JedisUtil.isBlank(sentinelMaster)) {
				cluster = new JedisCluster(nodes, poolConfig);
				log.info("JedisCluster nodes="+nodes);
			} else {
				Set<String> sentinels = JedisUtil.sentinels(nodes);
				log.info("JedisSentinel master="+sentinelMaster+",timeout="+timeout+",database="+database+" "+sentinels.size()+" sentinels="+sentinels);
				sentinel = new JedisSentinelPool(sentinelMaster, sentinels, poolConfig, timeout, password, database);
			}
		} else {
			pool = new JedisPool(poolConfig, host, port, timeout, password, database);
			log.info("RedisSessionManager host="+host+", port="+port+", database="+database+", timeout="+timeout+", prefix="+prefix+", maxInactiveInterval="+maxInactiveInterval);
		}
	}

	/**
	 * Stops the manager: closes the Redis connection and lets ManagerBase release its state.
	 *
	 * @throws LifecycleException if ManagerBase fails to stop
	 */
	protected void stopInternal() throws LifecycleException {
		setState(LifecycleState.STOPPING);
		JedisUtil.close(pool, sentinel, cluster);
		// Forget the closed connections so a later restart in a different mode does not
		// leave a stale, closed reference behind.
		pool = null;
		sentinel = null;
		cluster = null;
		super.stopInternal();
	}

	/**
	 * Creates a new session with this manager's max inactive interval and marks it dirty so the
	 * first flush writes it to Redis. StandardSession.setId() registers it with this manager.
	 * <p>
	 * NOTE(review): unlike ManagerBase.createSession this neither enforces maxActiveSessions nor
	 * uses the configured session id generator.
	 *
	 * @param sessionId the id to use; when blank a random 32-character hex id is generated
	 * @return the new session
	 */
	public Session createSession(String sessionId) {
		if (JedisUtil.isBlank(sessionId)) {
			sessionId = UUID.randomUUID().toString().replace("-", "");
		}
		RedisSession session = (RedisSession) createEmptySession();
		session.setId(sessionId);
		session.setNew(true);
		session.setValid(true);
		session.setCreationTime(System.currentTimeMillis());
		session.setMaxInactiveInterval(maxInactiveInterval);
		session.tellNew();
		session.setManager(this);
		session.dirty = true;
		log.debug(sessionId + " create");
		return session;
	}

	/** Returns a blank {@link RedisSession} bound to this manager. */
	public Session createEmptySession() {
		return new RedisSession(this);
	}

	/**
	 * Looks a session up locally first, then in Redis.
	 * <p>
	 * A locally held session has its access time refreshed. Otherwise the session is read from
	 * Redis and registered in the local map; concurrent lookups of the same id are serialized so
	 * it is loaded only once.
	 *
	 * @param id the session id; null yields null
	 * @return the session, or null if it exists neither locally nor in Redis
	 * @throws IOException as declared by ManagerBase.findSession
	 */
	public Session findSession(final String id) throws IOException {
		if (id == null) {
			return null;
		}
		Session session = super.findSession(id);
		if (session != null) {
			session.access();
			session.endAccess();
			return session;
		}
		// NOTE(review): locking on an interned String shares the monitor with any other code that
		// synchronizes on the same literal/interned value.
		synchronized (id.intern()) {
			session = super.findSession(id);
			if (session == null) {
				session = loadSession(id);
				if (session != null) {
					log.debug(id + " load");
					session.setManager(this);
					super.add(session);
				}
			}
		}
		return session;
	}

	/**
	 * Reads and deserializes a session from Redis. If the stored bytes cannot be turned back into
	 * a session, the key is deleted so the bad payload is not read again.
	 *
	 * @param id the session id
	 * @return the loaded session, or null if absent or unreadable
	 */
	private RedisSession loadSession(final String id) {
		final byte[] key = redisKey(id);
		return JedisUtil.exec(pool, sentinel, cluster, new Callback<RedisSession>() {
			public RedisSession execute(Jedis jedis) {
				byte[] bytes = jedis.get(key);
				if (bytes == null) {
					return null;
				}
				RedisSession empty = (RedisSession) createEmptySession();
				RedisSession loaded = (RedisSession) JedisUtil.deserialize(empty, bytes);
				if (loaded == null) {
					jedis.del(key);
				}
				return loaded;
			}
		});
	}

	/**
	 * Removes the session locally and deletes its key from Redis.
	 * NOTE(review): the delete goes through Callback.DELETE with a String key; it presumably uses
	 * Jedis' String overload, which encodes the key as UTF-8 like {@link #redisKey(String)}.
	 */
	public void remove(Session session, boolean update) {
		super.remove(session, update);
		Boolean removed = JedisUtil.exec(pool, sentinel, cluster, new Callback.DELETE(prefix + session.getId()));
		log.debug(session.getId() + " remove=" + removed);
	}

	// Persistence lives in Redis, so Tomcat's load/unload and rejected-session bookkeeping are no-ops.
	public void load() throws ClassNotFoundException, IOException { }
	public void unload() throws IOException { }
	public int getRejectedSessions() { return 0; }
	public void setRejectedSessions(int arg0) { }

	public void setHost(String host) { this.host = host; }
	public void setPort(int port) { this.port = port; }
	public void setTimeout(int timeout) { this.timeout = timeout; }
	/** Stores the trimmed password, treating null or blank as "no password". */
	public void setPassword(String password) {
		String trimmed = password == null ? null : password.trim();
		this.password = trimmed == null || trimmed.length() == 0 ? null : trimmed;
	}
	public void setDatabase(int database) { this.database = database; }
	public void setPrefix(String prefix) { this.prefix = JedisUtil.firstNonBlank(prefix); }
	public void setSentinelMaster(String sentinelMaster) { RedisSessionManager.sentinelMaster = sentinelMaster; }
	public void setMaxInactiveInterval(int maxInactiveInterval) { RedisSessionManager.maxInactiveInterval = maxInactiveInterval; }
	public void setSaveOnDirty(boolean saveOnDirty) { this.saveOnDirty = saveOnDirty; }
	public void setMaxTotal(int maxTotal) { poolConfig.setMaxTotal(maxTotal); }
	public void setMaxIdle(int maxIdle) { poolConfig.setMaxIdle(maxIdle); }
	public void setMinIdle(int minIdle) { poolConfig.setMinIdle(minIdle); }
	public void setLifo(boolean lifo) { poolConfig.setLifo(lifo); }
	public void setFairness(boolean fairness) { poolConfig.setFairness(fairness); }
	public void setMaxWaitMillis(long maxWaitMillis) { poolConfig.setMaxWaitMillis(maxWaitMillis); }
	public void setMinEvictableIdleTimeMillis(long minEvictableIdleTimeMillis) { poolConfig.setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis); }
	public void setSoftMinEvictableIdleTimeMillis(long softMinEvictableIdleTimeMillis) { poolConfig.setSoftMinEvictableIdleTimeMillis(softMinEvictableIdleTimeMillis); }
	public void setNumTestsPerEvictionRun(int numTestsPerEvictionRun) { poolConfig.setNumTestsPerEvictionRun(numTestsPerEvictionRun); }
	public void setEvictionPolicyClassName(String evictionPolicyClassName) { poolConfig.setEvictionPolicyClassName(evictionPolicyClassName); }
	public void setTestOnCreate(boolean testOnCreate) { poolConfig.setTestOnCreate(testOnCreate); }
	public void setTestOnBorrow(boolean testOnBorrow) { poolConfig.setTestOnBorrow(testOnBorrow); }
	public void setTestOnReturn(boolean testOnReturn) { poolConfig.setTestOnReturn(testOnReturn); }
	public void setTestWhileIdle(boolean testWhileIdle) { poolConfig.setTestWhileIdle(testWhileIdle); }
	public void setTimeBetweenEvictionRunsMillis(long timeBetweenEvictionRunsMillis) { poolConfig.setTimeBetweenEvictionRunsMillis(timeBetweenEvictionRunsMillis); }
	public void setBlockWhenExhausted(boolean blockWhenExhausted) { poolConfig.setBlockWhenExhausted(blockWhenExhausted); }
	public void setJmxEnabled(boolean jmxEnabled) { poolConfig.setJmxEnabled(jmxEnabled); }
	public void setJmxNamePrefix(String jmxNamePrefix) { poolConfig.setJmxNamePrefix(jmxNamePrefix); }
	public void setJmxNameBase(String jmxNameBase) { poolConfig.setJmxNameBase(jmxNameBase); }
	public void setHostAndPorts(String hostAndPorts) { this.hostAndPorts = hostAndPorts; }

	/**
	 * Pipeline valve that lets the request run and then passes the requested session id to
	 * {@link RedisSessionManager#afterRequest(String)} so the session is flushed to Redis.
	 * NOTE(review): for async requests the finally block presumably runs when the container
	 * thread returns, which may be before async processing has finished with the session.
	 */
	public static class RedisSessionValve extends ValveBase {
		RedisSessionManager redisSessionManager;

		public void setRedisSessionManager(RedisSessionManager redisSessionManager) {
			this.redisSessionManager = redisSessionManager;
		}

		public void invoke(Request request, Response response) throws IOException, ServletException {
			try {
				getNext().invoke(request, response);
			} finally {
				// An unbound valve (e.g. declared in configuration without this manager) must not
				// throw here, or the NPE would replace any exception raised by the request.
				RedisSessionManager manager = redisSessionManager;
				if (manager != null) {
					manager.afterRequest(request.getRequestedSessionId());
				}
			}
		}
	}

	/**
	 * StandardSession that tracks whether its attributes changed since it was created or loaded,
	 * so the manager can skip writing unchanged sessions back to Redis.
	 */
	public static class RedisSession extends StandardSession {
		private static final long serialVersionUID = 2L;
		/** True once an attribute has been set or removed; not serialized. */
		transient boolean dirty;

		public RedisSession(Manager manager) {
			super(manager);
		}

		/**
		 * Stores the attribute and marks the session dirty when the value actually changed.
		 * <p>
		 * The call is always delegated, even for a value equal to the current one, so that the
		 * caller's reference is the one kept and binding/attribute listeners fire as usual.
		 */
		public void setAttribute(String name, Object value, boolean notify) {
			Object oldValue = getAttribute(name);
			if (oldValue == null && value == null) {
				// Removing an attribute that is not there: nothing to store or persist.
				return;
			}
			// The same reference still counts as a change: the attribute may be a mutable object
			// (e.g. a Map) that was modified in place before being set again.
			boolean changed = oldValue == null || oldValue == value || !oldValue.equals(value);
			super.setAttribute(name, value, notify);
			if (changed) {
				dirty = true;
			}
		}

		/** Removes the attribute and marks the session dirty. */
		protected void removeAttributeInternal(String name, boolean notify) {
			super.removeAttributeInternal(name, notify);
			dirty = true;
		}
	}
}
